/*
 * @copyright (c) 2024, MacRsh
 *
 * @license SPDX-License-Identifier: Apache-2.0
 *
 * @date 2024-02-18    MacRsh       First version
 */

#include "module.h"

/* Sentinel head of the circular list of published modules.
 * Both `public` links initially point back at `root` itself, i.e. the list is
 * empty; module_find() walks from root.public.next until it reaches &root again. */
static struct module root = {.name = "root", .public ={&root, &root}};

/* Intentionally empty export entry.
 * NOTE(review): presumably MODULE_EXPORT emits the `_module_export_start` symbol
 * that module_export_handle() uses as the lower bound of the export table, and the
 * string selects where the entry is placed in that table - confirm in module.h. */
static void start(void)
{

}
MODULE_EXPORT(start, "0.end");

/* Intentionally empty export entry.
 * NOTE(review): presumably MODULE_EXPORT emits the `_module_export_end` symbol
 * that module_export_handle() uses as the (exclusive) upper bound of the export
 * table - confirm in module.h. */
static void end(void)
{

}
MODULE_EXPORT(end, "256.end");

/**
 * This function runs every exported module function.
 *
 * The export table is walked from _module_export_start (inclusive) up to
 * _module_export_end (exclusive), calling each entry in table order.
 */
void module_export_handle(void)
{
    const module_export_fn_t *entry = &_module_export_start;

    while (entry < &_module_export_end) {
        (*entry)();
        entry++;
    }
}

/**
 * This function allocates memory.
 *
 * The default implementation forwards directly to malloc().
 * NOTE(review): MODULE_WEAK presumably gives this definition weak linkage so that
 * an application can supply its own allocator - confirm in module.h.
 *
 * @param size The size of the memory.
 *
 * @return The pointer to the allocated memory (whatever malloc() returns).
 */
MODULE_WEAK void *module_malloc(size_t size)
{
    return malloc(size);
}

/**
 * This function releases memory.
 *
 * The default implementation forwards directly to free().
 * NOTE(review): MODULE_WEAK presumably gives this definition weak linkage so that
 * it can be overridden together with module_malloc() - confirm in module.h.
 *
 * @param memory The pointer to the memory to be released.
 */
MODULE_WEAK void module_free(void *memory)
{
    free(memory);
}

/**
 * This function finds a published module by name.
 *
 * The published-module list is circular with `root` as its sentinel, so the
 * walk starts after `root` and stops when it comes back to it.
 *
 * @param name The name of the module.
 *
 * @return The pointer to the module, or NULL if no published module has that
 *         name or if name is NULL.
 */
struct module *module_find(const char *name)
{
    /* strcmp() must not be given a null pointer */
    if (!name) {
        return NULL;
    }

    for (struct module *module = root.public.next; module != &root; module = module->public.next) {
        if (strcmp(module->name, name) == 0) {
            return module;
        }
    }
    return NULL;
}

/**
 * This function initializes a module.
 *
 * @param module The pointer to the module.
 * @param name The name of the module.
 * @param resource The pointer to the module dependency.
 * @param ops The pointer to the module operations.
 *
 * @return The result of the initialization.
 */
int module_init(struct module *module,
                const char *name,
                struct module_resource *resource,
                const struct module_ops *ops)
{
    static struct module_resource null_resource = {NULL};
    static struct module_ops null_ops = {NULL};

    if ((!module) || (!name) || (!name[0])) {
        return -EINVAL;
    }

    module->name = name;
    module->ref = 0;
    module->public.prev = module->public.next = module;
    module->resource = resource ? resource : &null_resource;
    module->ops = ops ? ops : &null_ops;
    module->pos = -1;
    module->io.input.module = NULL;
    module->io.input.next = NULL;
    module->io.output.module = NULL;
    module->io.output.next = NULL;
    return 0;
}

/**
 * This function gets the resource attached to a module.
 *
 * @param module The pointer to the module.
 *
 * @return The pointer to the module resource, or NULL if module is NULL.
 */
struct module_resource *module_get_resource(struct module *module)
{
    return (module != NULL) ? module->resource : NULL;
}

/**
 * This function publishes a module, making it discoverable by module_find().
 *
 * The module is appended at the tail of the published list (just before the
 * `root` sentinel).
 *
 * @param module The pointer to the module.
 *
 * @return 0 on success, -EINVAL if module is NULL, -ENOENT if the module is
 *         already linked into a list, -EEXIST if a published module already
 *         uses the same name.
 */
int module_public(struct module *module)
{
    if (module == NULL) {
        return -EINVAL;
    }
    /* Only a self-linked (unpublished) node may be inserted */
    if (!((module->public.next == module) && (module->public.prev == module))) {
        return -ENOENT;
    }
    if (module_find(module->name) != NULL) {
        return -EEXIST;
    }

    /* Splice in between the current tail and the sentinel */
    struct module *tail = root.public.prev;

    module->public.prev = tail;
    module->public.next = &root;
    tail->public.next = module;
    root.public.prev = module;
    return 0;
}

/**
 * This function unpublishes a module, removing it from the published list.
 *
 * @param module The pointer to the module.
 *
 * @return 0 on success, -EINVAL if module is NULL, -ENOENT if the module is
 *         not currently linked into the list.
 */
int module_unpublic(struct module *module)
{
    if (module == NULL) {
        return -EINVAL;
    }

    struct module *after = module->public.next;
    struct module *before = module->public.prev;

    /* A self-linked node is not published */
    if ((after == module) || (before == module)) {
        return -ENOENT;
    }

    /* Bridge the neighbours, then return the node to the self-linked state */
    after->public.prev = before;
    before->public.next = after;
    module->public.next = module;
    module->public.prev = module;
    return 0;
}

/* Bind `module` to the reference slot `reference` and take a reference on it.
 * Returns 0 on success, or the negative value returned by the module's init
 * operation (in which case nothing is modified). */
static int reference_module_load(struct reference *reference, struct module *module)
{
    int first_load = (module->ref == 0);

    /* The init operation, when provided, runs only on the very first load */
    if (first_load && (module->ops->init != NULL)) {
        int status = module->ops->init(module);
        if (status < 0) {
            return status;
        }
    }

    reference->next = NULL;
    reference->module = module;
    module->ref++;
    return 0;
}

/* Remove `reference` (whose predecessor in the chain is `prev`) and drop one
 * reference on `module`. `prev == reference` means the head slot of the chain
 * is being removed. Returns 0 on success, or the negative value returned by
 * the module's deinit operation (in which case nothing is modified). */
static int reference_module_unload(struct reference *prev,
                                   struct reference *reference,
                                   struct module *module)
{
    struct reference *victim = reference;

    /* The deinit operation, when provided, runs only when the last reference goes away */
    if ((module->ref == 1) && (module->ops->deinit)) {
        int status = module->ops->deinit(module);
        if (status < 0) {
            return status;
        }
    }

    /* The head slot itself is never freed: it takes over the module of its
     * successor (or becomes empty), and the successor is the node released */
    if (prev == reference) {
        victim = reference->next;
        prev->module = victim ? victim->module : NULL;
    }

    /* Unlink and release the node, if there is one to release */
    if (victim) {
        prev->next = victim->next;
        module_free(victim);
    }

    module->ref--;
    return 0;
}

/**
 * This function loads an input module.
 *
 * The input module is appended to the end of the module's input chain. If the
 * input module's initialization fails, the chain is left exactly as it was.
 *
 * @param module The pointer to the module.
 * @param input The pointer to the input module.
 *
 * @return 0 on success, -EINVAL on invalid arguments or if the input module has
 *         no input operation, -EEXIST if the input module is already loaded,
 *         -ENOMEM if a chain node cannot be allocated, otherwise the negative
 *         value returned by the input module's init operation.
 */
int module_load_input(struct module *module, struct module *input)
{
    if ((!module) || (!input) || (!input->ops->input)) {
        return -EINVAL;
    }

    for (struct reference *reference = &module->io.input;
         reference != NULL;
         reference = reference->next) {

        /* Multiple references to the same module are not supported,
         * because it is impossible to determine which module needs
         * to be unloaded when a reference is removed */
        if (reference->module == input) {
            return -EEXIST;
        }

        /* As the first reference, the module is loaded normally */
        if (reference->module == NULL) {
            return reference_module_load(reference, input);
        }

        /* Subsequent references need to create a new reference link after the first reference,
         * because the first reference is used to track the base module information */
        if (reference->next == NULL) {
            struct reference *node = module_malloc(sizeof(*node));
            if (!node) {
                return -ENOMEM;
            }

            /* Link the node only once loading succeeded, so that a failed load
             * neither leaks it nor leaves an uninitialized node in the chain */
            int ret = reference_module_load(node, input);
            if (ret < 0) {
                module_free(node);
                return ret;
            }
            reference->next = node;
            return 0;
        }
    }
    return 0;
}

/**
 * This function loads an output module.
 *
 * The output module is appended to the end of the module's output chain. If the
 * output module's initialization fails, the chain is left exactly as it was.
 *
 * @param module The pointer to the module.
 * @param output The pointer to the output module.
 *
 * @return 0 on success, -EINVAL on invalid arguments or if the output module has
 *         no output operation, -EEXIST if the output module is already loaded,
 *         -ENOMEM if a chain node cannot be allocated, otherwise the negative
 *         value returned by the output module's init operation.
 */
int module_load_output(struct module *module, struct module *output)
{
    if ((!module) || (!output) || (!output->ops->output)) {
        return -EINVAL;
    }

    for (struct reference *reference = &module->io.output;
         reference != NULL;
         reference = reference->next) {

        /* The same module cannot be referenced twice in one chain */
        if (reference->module == output) {
            return -EEXIST;
        }

        /* An empty head slot is filled in place */
        if (reference->module == NULL) {
            return reference_module_load(reference, output);
        }

        /* At the end of the chain: append a new node */
        if (reference->next == NULL) {
            struct reference *node = module_malloc(sizeof(*node));
            if (!node) {
                return -ENOMEM;
            }

            /* Link the node only once loading succeeded, so that a failed load
             * neither leaks it nor leaves an uninitialized node in the chain */
            int ret = reference_module_load(node, output);
            if (ret < 0) {
                module_free(node);
                return ret;
            }
            reference->next = node;
            return 0;
        }
    }
    return 0;
}

/**
 * This function unloads an input module.
 *
 * @param module The pointer to the module.
 * @param input The pointer to the input module.
 *
 * @return The result of the unloading: -EINVAL on invalid arguments or if the
 *         input module has no input operation, -ENOENT if the input module is
 *         not in the module's input chain, otherwise the result of removing
 *         the reference.
 */
int module_unload_input(struct module *module, struct module *input)
{
    if ((!module) || (!input) || (!input->ops->input)) {
        return -EINVAL;
    }

    /* Walk the chain while remembering the predecessor;
     * for the head slot the predecessor is the head itself */
    struct reference *before = &module->io.input;
    struct reference *current = &module->io.input;

    while (current != NULL) {
        if (current->module == input) {
            return reference_module_unload(before, current, input);
        }
        before = current;
        current = current->next;
    }
    return -ENOENT;
}

/**
 * This function unloads an output module.
 *
 * @param module The pointer to the module.
 * @param output The pointer to the output module.
 *
 * @return The result of the unloading: -EINVAL on invalid arguments or if the
 *         output module has no output operation, -ENOENT if the output module
 *         is not in the module's output chain, otherwise the result of removing
 *         the reference.
 */
int module_unload_output(struct module *module, struct module *output)
{
    if ((!module) || (!output) || (!output->ops->output)) {
        return -EINVAL;
    }

    /* Walk the chain while remembering the predecessor;
     * for the head slot the predecessor is the head itself */
    struct reference *before = &module->io.output;
    struct reference *current = &module->io.output;

    while (current != NULL) {
        if (current->module == output) {
            return reference_module_unload(before, current, output);
        }
        before = current;
        current = current->next;
    }
    return -ENOENT;
}

static ssize_t reference_input(struct reference *reference, int pos, void *buf, size_t count)
{
    if (reference->next) {
        ssize_t ret = reference_input(reference->next, pos, buf, count);
        if (ret <= 0) {
            return ret;
        }
        count = ret;
    }
    return reference->module->ops->input(reference->module, pos, buf, count);
}

/**
 * This function reads data from an module.
 *
 * @param module The pointer to the module.
 * @param buf The pointer to the buffer.
 * @param count The size of the buffer.
 *
 * @return The number of the actual bytes read, otherwise an error code.
 */
ssize_t module_read(struct module *module, void *buf, size_t count)
{
    if ((!module) || ((!buf) && (count != 0))) {
        return -EINVAL;
    }
    if (!module->io.input.module) {
        return -ENOSYS;
    }

    return reference_input(&module->io.input, module->pos, buf, count);
}

static ssize_t reference_output(struct reference *reference, int pos, const void *buf, size_t count)
{
    ssize_t ret = reference->module->ops->output(reference->module, pos, buf, count);
    if (ret <= 0) {
        return ret;
    }
    if (reference->next) {
        return reference_output(reference->next, pos, buf, ret);
    }
    return ret;
}

/**
 * This function writes data to an module.
 *
 * @param module The pointer to the module.
 * @param buf The pointer to the buffer.
 * @param count The size of the buffer.
 *
 * @return The number of the actual bytes written, otherwise an error code.
 */
ssize_t module_write(struct module *module, const void *buf, size_t count)
{
    if ((!module) || ((!buf) && (count != 0))) {
        return -EINVAL;
    }
    if (!module->io.output.module) {
        return -ENOSYS;
    }

    return reference_output(&module->io.output, module->pos, buf, count);
}

/* Field extractors for the packed ioctl command word.
 * Every use of the argument is parenthesized so that an expression such as
 * `a | b` can be passed without operator precedence changing the result. */
#define MODULE_IOCTL_MASK(_cmd) \
    ((((_cmd) >> 28) & 0x9) == 0x9)                             /**< command identifier, maximum 4bit: 1001 */
#define MODULE_IOCTL_IO(_cmd)       (((_cmd) >> 24) & 0x1)      /**< 0: output, 1: input */
#define MODULE_IOCTL_TIER(_cmd)     (((_cmd) >> 16) & 0xff)     /**< module level that the command targets */
#define MODULE_IOCTL_CMD(_cmd)      ((_cmd) & 0xffff)           /**< command code */

/* Dispatch an ioctl to one stage of a reference chain.
 * The tier field of `cmd` selects the stage: 0 is the head of the chain, and
 * each further tier moves one reference down. The stage receives only the
 * command-code field of `cmd`.
 * Returns -ENOENT if the chain is shorter than the requested tier, -ENOSYS if
 * the selected slot holds no module or the module has no ioctl operation,
 * otherwise the result of the module's ioctl operation. */
static int reference_ioctl(struct reference *reference, int pos, int cmd, void *arg)
{
    for (size_t i = 0; i < MODULE_IOCTL_TIER(cmd); i++) {
        if (!reference->next) {
            return -ENOENT;
        }
        reference = reference->next;
    }

    /* The head slot is empty when nothing is loaded, and ioctl is an optional
     * operation: neither may be dereferenced or called blindly */
    if ((!reference->module) || (!reference->module->ops->ioctl)) {
        return -ENOSYS;
    }

    return reference->module->ops->ioctl(reference->module, pos, MODULE_IOCTL_CMD(cmd), arg);
}

/**
 * This function controls a module.
 *
 * @param module The pointer to the module.
 * @param cmd The command.
 * @param args The argument passed through to the targeted module.
 *
 * @return The result of the operation, -EINVAL if module is NULL, the command
 *         identifier bits are missing or the command code is 0.
 *
 * @note The command needs to use the MODULE_IOCTL(io, tier, cmd) macro,
 *       where io is 0(output) or 1(input),
 *       tier is an 8-bit value that specifies the module level that the command targets,
 *       cmd is a 16-bit command code.
 */
int module_ioctl(struct module *module, int cmd, void *args)
{
    if (module == NULL) {
        return -EINVAL;
    }
    if ((!MODULE_IOCTL_MASK(cmd)) || (!MODULE_IOCTL_CMD(cmd))) {
        return -EINVAL;
    }

    /* The io bit chooses which chain the command travels along */
    if (MODULE_IOCTL_IO(cmd)) {
        return reference_ioctl(&module->io.input, module->pos, cmd, args);
    }
    return reference_ioctl(&module->io.output, module->pos, cmd, args);
}

/**
 * This function moves the position of a module.
 *
 * MODULE_SEEK_SET sets the position to offset, MODULE_SEEK_CUR adds offset to
 * the current position, and MODULE_SEEK_END sets the position to -1 + offset.
 *
 * @param module The pointer to the module.
 * @param offset The offset to seek.
 * @param whence The position to seek.
 *
 * @return The new position, otherwise an error code (-EINVAL if module is NULL
 *         or whence is not one of the three supported values).
 */
int module_seek(struct module *module, int offset, int whence)
{
    if (module == NULL) {
        return -EINVAL;
    }

    if (whence == MODULE_SEEK_SET) {
        module->pos = offset;
    } else if (whence == MODULE_SEEK_CUR) {
        module->pos += offset;
    } else if (whence == MODULE_SEEK_END) {
        module->pos = -1 + offset;
    } else {
        /* Unknown whence: the position is left untouched */
        return -EINVAL;
    }
    return module->pos;
}
